import os
import argparse
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader, Subset
from utils import CIFAR10Generation
from transformers import AutoProcessor, CLIPModel


class ImgFeatureExtractor():
    """Extract CLIP ViT-L/14 image embeddings for CIFAR-10, one file per image.

    Each embedding is stored as a float16 tensor in
    ``<args.embedding_path>/<index:05d>.pt``. A batch is skipped when its last
    output file already exists, so an interrupted run can be resumed.
    """

    def __init__(self, args):
        """
        Args:
            args: parsed CLI namespace. Reads ``embedding_path``, ``cifar10_path``,
                ``index`` and, if present, ``num_chunks`` (defaults to 1).
        """
        # Fall back to CPU so the script does not crash on machines without CUDA.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.args = args
        # Backbone selector; only "VIT_L" (openai/clip-vit-large-patch14) is implemented.
        self.model = "VIT_L"

    def extract_feature(self):
        """Embed this process's chunk of CIFAR-10 and save each embedding to disk.

        Raises:
            ValueError: if ``self.model`` names an unsupported backbone, or if the
                chunk selection in ``args`` is invalid (see ``get_cifar10_loader``).
        """
        bsz = 16
        if self.model not in ("VIT_L",):
            # Fail early instead of hitting an undefined/stale `image_features` later.
            raise ValueError(f"Unsupported model: {self.model!r}")
        os.makedirs(self.args.embedding_path, exist_ok=True)
        dst_train = CIFAR10Generation(self.args.cifar10_path)
        num_chunks = getattr(self.args, "num_chunks", 1)
        dataloader = self.get_cifar10_loader(dst_train, bsz, num_chunks=num_chunks)
        # Model
        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
        processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")

        # Pure inference: without no_grad every batch would build an autograd graph.
        with torch.no_grad():
            # assumes CIFAR10Generation yields (index, _) pairs — TODO confirm in utils
            for indices, _ in tqdm(dataloader, total=len(dataloader)):
                index_list = [int(i) for i in indices]
                out_paths = [
                    os.path.join(self.args.embedding_path, f"{i:05d}.pt")
                    for i in index_list
                ]
                # Files are written in batch order, so if the last one exists the
                # whole batch was already processed by an earlier run.
                if os.path.exists(out_paths[-1]):
                    continue
                # presumably dst_train.cifar10[i] is an (image, label) pair — verify
                images = [dst_train.cifar10[i][0] for i in index_list]
                inputs = processor(images=images, return_tensors="pt").to(self.device)
                image_features = model.get_image_features(**inputs).to(torch.float16)

                for feature, out_path in zip(image_features, out_paths):
                    # clone(): torch.save preserves views, so saving the bare row
                    # would serialize the entire batch's storage into every file.
                    torch.save(feature.clone(), out_path)

    def get_cifar10_loader(self, dst_train, bsz, num_chunks=1):
        """Return an ordered DataLoader over chunk ``self.args.index`` of ``dst_train``.

        The dataset is split into ``num_chunks`` contiguous chunks of
        ``len(dst_train) // num_chunks`` items; the last chunk also takes the
        remainder.

        Args:
            dst_train: indexable dataset with ``__len__``.
            bsz: batch size.
            num_chunks: number of chunks the dataset is divided into.

        Raises:
            ValueError: if ``num_chunks < 1`` or ``self.args.index`` is not in
                ``[0, num_chunks)``.
        """
        if num_chunks < 1:
            raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")
        chunk_index = self.args.index
        if not 0 <= chunk_index < num_chunks:
            raise ValueError(
                f"index must be in [0, {num_chunks}), got {chunk_index}"
            )
        chunk_size = len(dst_train) // num_chunks
        start = chunk_index * chunk_size
        # The last chunk absorbs the remainder of an uneven split.
        stop = len(dst_train) if chunk_index == num_chunks - 1 else start + chunk_size
        subset_dataset = Subset(dst_train, indices=range(start, stop))
        return DataLoader(subset_dataset, batch_size=bsz, shuffle=False, num_workers=2)


def get_args():
    """Parse the command line for the CIFAR-10 CLIP embedding extractor.

    Returns:
        argparse.Namespace with ``index`` (int, default 0), ``cifar10_path``
        (str, default "datasets") and ``embedding_path`` (str, required).
    """
    option_table = (
        ("--index", {"default": 0, "type": int, "help": "split task"}),
        ("--cifar10_path", {"default": "datasets", "type": str, "help": "path to cifar10"}),
        ("--embedding_path", {"required": True, "type": str, "help": "path to save embeddings"}),
    )
    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def main():
    """Parse CLI options and run CIFAR-10 feature extraction with them."""
    ImgFeatureExtractor(get_args()).extract_feature()


if __name__ == "__main__":
    # Enable cuDNN autotuning before any model work starts in main();
    # presumably worthwhile because batch input sizes are fixed — verify.
    torch.backends.cudnn.benchmark = True
    main()